import os
import numpy as np
import torch
from tqdm import tqdm
import cv2
import open3d as o3d
import trimesh
import torch.nn.functional as F
import torch.nn as nn
from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
from pytorch3d.transforms import quaternion_apply, quaternion_invert
from pytorch3d.ops import knn_points
from np.np_train import Runner
from np.np_dataset import Dataset
from rich.console import Console
import time


def _pull_samples_to_gaussian_disks(gspull, dataset, sdf_network, num_samples=10000):
    """Return the SDF "pull to Gaussian disk" loss for one batch of query samples.

    Query samples from the NeuralPull dataset are moved along their normalized SDF
    gradient by their predicted SDF value (``samples - n * sdf``). Each moved sample is
    matched (nearest neighbour) to a point of the dataset point cloud, which is mapped
    back to a Gaussian index through ``dataset.downsample_idx[points_idx]``. The sample
    is then penalized by half its squared distance to that Gaussian's center, measured
    in the frame returned by ``gspull.get_covariance(..., inverse_scales=True)``.
    The Gaussian centers and covariances are detached, so gradients of this term only
    reach the SDF network.

    Args:
        gspull: GSPull model providing ``points`` and ``get_covariance``.
        dataset: NeuralPull dataset (``gspull.neus.dataset``).
        sdf_network: NeuralPull SDF network (``gspull.neus.sdf_network``).
        num_samples: number of query samples requested from ``dataset.get_train_data``.

    Returns:
        tuple: ``(sdf_loss, batch_selected_idx, sample_normals)``. ``sdf_loss`` is the
        scalar loss, ``batch_selected_idx`` is the (CUDA) index of the Gaussian matched
        to every sample, and ``sample_normals`` is the normalized SDF gradient at every
        sample.
    """
    # Named ``cloud_points`` so it cannot shadow anything named ``points`` elsewhere.
    cloud_points, samples, _point_gt, points_idx = dataset.get_train_data(num_samples)

    samples.requires_grad = True
    gradients_sample = sdf_network.gradient(samples).squeeze()
    sdf_sample = sdf_network.sdf(samples)
    sample_normals = F.normalize(gradients_sample, dim=1)
    sample_moved = samples - sample_normals * sdf_sample
    # NOTE(review): NeuralPull's Chamfer term
    # (gspull.neus.ChamferDisL1(cloud_points[None], sample_moved[None])) is not part of
    # this objective; add it here if it is meant to be.

    # Moved samples live in the dataset's normalized frame; map them back to world space.
    scaled_sample_moved = sample_moved * dataset.shape_scale + dataset.shape_center
    knn = knn_points(sample_moved[None], cloud_points[None], K=1)
    knn_idx = knn.idx[0, :, 0]

    gaussian_inv_scaled_rotation = gspull.get_covariance(
        return_full_matrix=True, return_sqrt=True, inverse_scales=True,
        scaling_factor=100, enlarge_minaxis=1)
    # Index of the Gaussian behind each matched point-cloud point.
    batch_selected_idx = torch.arange(gspull.points.shape[0])[dataset.downsample_idx][points_idx][knn_idx].cuda()
    closest_gaussian_inv_scaled_rotation = gaussian_inv_scaled_rotation[batch_selected_idx].detach()
    surf_points = gspull.points[batch_selected_idx].detach().clone()

    shift = scaled_sample_moved - surf_points
    warped_shift = closest_gaussian_inv_scaled_rotation.transpose(-1, -2) @ shift[..., None]
    sq_dist = (warped_shift[..., 0] * warped_shift[..., 0]).sum(dim=-1).clamp(min=0., max=1e8)
    # -log(exp(-d / 2)) == d / 2. Computing it directly avoids exp() underflowing to 0,
    # which would give an infinite loss and NaN gradients for distant samples. The
    # per-sample values are averaged so the loss is a scalar that backward() accepts.
    sdf_loss = 0.5 * sq_dist.mean()
    return sdf_loss, batch_selected_idx, sample_normals


def _normal_consistency_loss(gspull, dataset, sdf_network, batch_selected_idx,
                             sample_normals, iteration, pulled_normal_start_iteration=10000):
    """Return the (unweighted) normal-consistency loss between Gaussians and the SDF.

    Each term is ``mean(| |<a, b>| - 1 |)``, i.e. it is zero when the two unit vectors
    are parallel or anti-parallel.

    * Center term: Gaussian normals (``gspull.get_normals()``) of the matched Gaussians
      versus a reference SDF normal.
    * Pulled term, only when ``iteration > pulled_normal_start_iteration``: the matched
      Gaussian centers are pulled onto the SDF zero level set, and the Gaussian normals
      are compared with the SDF normal at the pulled positions.
    """
    gaussian_normals = gspull.get_normals()[batch_selected_idx]
    reference_normals = sample_normals
    pulled_term = 0.
    if iteration > pulled_normal_start_iteration:
        # Gaussian centers expressed in the dataset's normalized frame.
        centers = (gspull.points[batch_selected_idx] - dataset.shape_center) / dataset.shape_scale
        center_normals = F.normalize(sdf_network.gradient(centers).squeeze(), dim=1)
        center_sdf = sdf_network.sdf(centers)
        moved_centers = centers - center_normals * center_sdf
        pulled_normals = F.normalize(sdf_network.gradient(moved_centers).squeeze(), dim=1)
        pulled_term = torch.abs(torch.sum(pulled_normals * gaussian_normals, -1).abs() - 1).mean()
        # NOTE(review): past this iteration the center term compares against the SDF
        # normal at the Gaussian centers instead of at the query samples. This matches
        # the behavior of the code this was refactored from; confirm it is intended.
        reference_normals = center_normals
    center_term = torch.abs(torch.sum(reference_normals * gaussian_normals, -1).abs() - 1).mean()
    return pulled_term + center_term


def training_with_sdf_regularization(args):
    """Fine-tune a GSPull model from a trained 3DGS checkpoint with SDF regularization.

    The GSPull Gaussians are initialized from ``nerfmodel.gaussians``, and a NeuralPull
    ``Runner`` is attached as ``gspull.neus``. The model is then optimized with a render
    loss plus optional terms:

    * scaling loss (``train_scaling``),
    * SDF disk-pull loss (``train_sdf``),
    * SDF normal consistency (``train_normal``, which requires ``train_sdf``).

    NOTE(review): most configuration (paths, ``num_iterations``, ``start_iteration``,
    learning rates, ``train_*`` flags, ``points``/``colors``, ``loss_fn``, ``logger``,
    ``gspull_checkpoint_path`` ...) is read from names not defined in this function.
    They are presumably module-level globals; confirm.

    Args:
        args: namespace; reads ``output_dir``, ``dataset_name``, ``image_resolution``
            and ``white_bg``.

    Returns:
        str: path of the final saved ``.pt`` model.

    Raises:
        ValueError: if ``train_normal`` is enabled without ``train_sdf``.
    """
    console = Console()

    # The normal loss reuses the SDF samples, so validate the flags before any work.
    if train_normal and not train_sdf:
        raise ValueError('require train_sdf=True for train_normal!')

    checkpoint_path = os.path.join(args.output_dir, 'pullgs_fromgs7000')

    nerfmodel = GaussianSplattingWrapper(
        source_path=source_path,
        output_path=gs_checkpoint_path,
        iteration_to_load=iteration_to_load,
        load_gt_images=True,
        eval_split=use_eval_split,
        eval_split_interval=n_skip_images_for_eval_split,
        dataset_name=args.dataset_name,
        image_resolution=args.image_resolution,
    )

    # ====================Initialize GSPull model====================
    gspull = GSPull(
        nerfmodel=nerfmodel,
        points=points,
        colors=colors,
        initialize=True,
    )
    with torch.no_grad():
        console.print("Initializing 3D gaussians from 3D gaussians...")
        gspull._scales[...] = nerfmodel.gaussians._scaling.detach()
        gspull._quaternions[...] = nerfmodel.gaussians._rotation.detach()
        gspull.all_densities[...] = nerfmodel.gaussians._opacity.detach()
        gspull._sh_coordinates_dc[...] = nerfmodel.gaussians._features_dc.detach()
        gspull._sh_coordinates_rest[...] = nerfmodel.gaussians._features_rest.detach()
    gspull.neus = Runner(checkpoint_path, None)

    # ====================Initialize optimizer====================
    opt_params = OptimizationParams(
        iterations=num_iterations,
        position_lr_init=position_lr_init,
        position_lr_final=position_lr_final,
        position_lr_delay_mult=position_lr_delay_mult,
        position_lr_max_steps=position_lr_max_steps,
        feature_lr=feature_lr,
        opacity_lr=opacity_lr,
        scaling_lr=scaling_lr,
        rotation_lr=rotation_lr,
    )
    optimizer = GSPullOptimizer(gspull, opt_params)

    # ====================Start training====================
    gspull.train()
    # NOTE(review): training resumes at ``start_iteration``, so the SDF learning-rate
    # schedule below (indexed by ``iteration - start_iteration``) starts at 1; confirm.
    iteration = start_iteration
    epoch = 0
    train_losses = []
    last_visual_iteration = None
    last_save_iteration = None
    loss = None
    normal_weight = 0.1
    bg_color = torch.Tensor([1.0, 1.0, 1.0]) if args.white_bg else None

    while iteration < num_iterations:
        shuffled_idx = torch.randperm(len(nerfmodel.training_cameras))
        train_num_images = len(shuffled_idx)

        for start_idx in range(0, train_num_images, train_num_images_per_batch):
            iteration += 1

            sdfnet_lr = gspull.neus.get_learning_rate_at_iteration(iteration - start_iteration)
            optimizer.update_learning_rate(iteration, sdfnet_lr=sdfnet_lr)

            end_idx = min(start_idx + train_num_images_per_batch, train_num_images)
            camera_indices = shuffled_idx[start_idx:end_idx]
            # The rasterizer call takes a single camera index; .item() therefore
            # requires train_num_images_per_batch == 1.
            camera_idx = camera_indices.item()

            # Render and compare against ground truth, both reshaped to (B, 3, H, W).
            outputs = gspull.render_image_gaussian_rasterizer(
                camera_indices=camera_idx,
                verbose=False,
                bg_color=bg_color,
                use_pulled=False,
            )
            pred_rgb = outputs['image'].view(-1, gspull.image_height, gspull.image_width, 3)
            pred_rgb = pred_rgb.transpose(-1, -2).transpose(-2, -3)

            gt_image = nerfmodel.get_gt_image(camera_indices=camera_indices)
            gt_rgb = gt_image.view(-1, gspull.image_height, gspull.image_width, 3)
            gt_rgb = gt_rgb.transpose(-1, -2).transpose(-2, -3)

            loss = loss_fn(pred_rgb, gt_rgb)
            render_loss = loss.item()
            # Defaults so the logging below works whichever terms are enabled.
            scaling_loss = sdf_loss = normal_loss = 0.

            if train_scaling:
                # Push each Gaussian's smallest scale down toward 1e-4 (flattening it).
                clamped_scaling = torch.clamp(gspull.scaling.min(1)[0], min=1e-4)
                scaling_loss = torch.abs(clamped_scaling - 1e-4).mean()
                loss = loss + 100 * scaling_loss

            if train_sdf:
                dataset = gspull.neus.dataset
                sdf_network = gspull.neus.sdf_network
                sdf_loss, batch_selected_idx, sample_normals = _pull_samples_to_gaussian_disks(
                    gspull, dataset, sdf_network)
                loss = loss + sdf_loss

                if train_normal:
                    normal_loss = _normal_consistency_loss(
                        gspull, dataset, sdf_network, batch_selected_idx,
                        sample_normals, iteration)
                    loss = loss + normal_weight * normal_loss

            loss.backward()

            # opacity/repulsion terms are not computed by this loop, so they are not logged.
            if render_loss != 0:
                logger.add_scalar('Loss/render_loss', render_loss, global_step=iteration)
                logger.add_scalar('Loss/scaling_loss', float(scaling_loss), global_step=iteration)
                logger.add_scalar('Loss/sdf_loss', float(sdf_loss), global_step=iteration)
                logger.add_scalar('Loss/norm_loss', float(normal_loss), global_step=iteration)

            if (iteration % 200 == 0 and iteration != last_visual_iteration and iteration > 9000) or iteration == num_iterations:
                with torch.no_grad():
                    gspull.marching_cubes(iteration, gspull_checkpoint_path, vertex_color=True, thres=0.002)
                    gspull.visual_point_cloud(iteration, gspull_checkpoint_path)
                    gspull.validate_image(pred_rgb, camera_idx, iteration, gspull_checkpoint_path, )
                    gspull.validate_normal_image(1, iteration, gspull_checkpoint_path)
                last_visual_iteration = iteration
                torch.cuda.empty_cache()

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            if iteration % 50 == 0:
                train_losses.append(loss.item())
                print(iteration, iteration - start_iteration, loss.item())

            if iteration % save_model_every_n_iterations == 0 and iteration != last_save_iteration:
                console.print("Saving model...")
                model_path = os.path.join(gspull_checkpoint_path, f'{iteration}.pt')
                gspull.save_model(path=model_path,
                                  train_losses=train_losses,
                                  epoch=epoch,
                                  iteration=iteration,
                                  optimizer_state_dict=optimizer.state_dict(),
                                  )
                gspull.save_ply(gspull_checkpoint_path, iteration)
                console.print("Model saved.")
                if iteration > 9000:
                    gspull.neus.save_checkpoint(gspull_checkpoint_path, iteration)
                last_save_iteration = iteration

            if iteration >= num_iterations:
                break
        epoch += 1

    # ``loss`` is still None if the loop never ran (start_iteration >= num_iterations).
    final_loss = loss.detach().item() if loss is not None else float('nan')
    console.print(f"Training finished after {num_iterations} iterations with loss={final_loss}.")
    console.print("Saving final model...")
    model_path = os.path.join(gspull_checkpoint_path, f'{iteration}.pt')
    gspull.save_model(path=model_path,
                      train_losses=train_losses,
                      epoch=epoch,
                      iteration=iteration,
                      optimizer_state_dict=optimizer.state_dict(),
                      )

    console.print("Final model saved.")
    return model_path